from memory_profiler import profile, LogFile
import sys
import os
import logging

from mmd import mix_rbf_mmd2, rbf_mmd2, batched_rbf_mmd2
from copy import deepcopy

from utils import cwd, set_deterministic, save_results
from data_utils import huber, assign_data

from data_utils import _get_loader, _get_MNIST32, _get_CIFAR10, _get_credit_card, _get_TON
from utils import CNN_Net_32, LogisticRegression, get_trained_feature_extractor, get_accuracy
import numpy as np
import torch
import torchvision.models as models
import torch.nn as nn


from tqdm import tqdm
import argparse
from os.path import join as oj


@profile
def get_MMD_values(D_Xs, D_Ys, V_X, V_Y, sigma_list=None, batch_size=1024, device=None):
    """Score each sample set in ``D_Xs`` by its negative MMD to ``V_X``.

    For every ``D_X``, both ``D_X`` and ``V_X`` are truncated to the shorter of
    the two lengths. The squared MMD is computed with ``batched_rbf_mmd2``
    (batched to avoid OOM), floored at 1e-6, square-rooted and negated.

    Args:
        D_Xs: Iterable of per-vendor sample sets (presumably feature tensors
            produced by ``get_extracted`` -- confirm at call sites).
        D_Ys: Unused; kept for interface compatibility.
        V_X: Reference sample set that every ``D_X`` is compared against.
        V_Y: Unused; kept for interface compatibility.
        sigma_list: RBF bandwidths forwarded to ``batched_rbf_mmd2``.
            Defaults to ``[1, 2, 5, 10]``.
        batch_size: Forwarded to ``batched_rbf_mmd2``.
        device: Forwarded to ``batched_rbf_mmd2``. Defaults to CUDA when it is
            available, otherwise CPU.

    Returns:
        List of Python floats, ``-sqrt(max(1e-6, MMD^2))`` for each ``D_X``.
    """
    # None sentinel instead of a shared mutable list default.
    if sigma_list is None:
        sigma_list = [1, 2, 5, 10]
    # Previously hard-coded to CUDA, which broke callers on CPU-only hosts.
    if device is None:
        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # Floor keeps sqrt well-defined when the MMD^2 estimate is ~0 or negative.
    floor = torch.tensor(1e-6)
    neg_MMDs = []
    for D_X in D_Xs:
        min_len = min(len(D_X), len(V_X))
        MMD2 = batched_rbf_mmd2(D_X[:min_len], V_X[:min_len], sigma_list, device=device, batch_size=batch_size)
        neg_MMDs.append(-torch.sqrt(max(floor, MMD2)).item())
    return neg_MMDs

def get_extracted(model, loader, device):
    """Run ``model`` over every batch of ``loader`` and stack the outputs.

    Args:
        model: Module used as a feature extractor. It is moved to ``device``
            and switched to eval mode via ``model.eval()``.
        loader: Iterable of ``(inputs, targets)`` pairs. Targets are ignored.
        device: Device the inputs are moved to before the forward pass.

    Returns:
        Tensor of model outputs concatenated along dim 0, computed under
        ``torch.no_grad()``.

    Raises:
        RuntimeError: From ``torch.cat`` if ``loader`` yields no batches.
    """
    model = model.to(device)
    model.eval()
    outputs = []
    with torch.no_grad():
        for batch_data, _ in loader:
            # Targets are not needed for feature extraction, so they are not
            # copied to the device (avoids wasted transfers / device memory).
            outputs.append(model(batch_data.to(device)))
    return torch.cat(outputs)

from sklearn.utils import resample


def run_exp(dataset, N, size, Q_dataset, not_huber=False, heterogeneity='normal', n_trials=1):
    """Run ``n_trials`` MMD valuation trials for ``N`` vendors of ``size`` samples.

    Loads dataset ``dataset`` and builds its model. It then obtains a feature
    extractor from ``get_trained_feature_extractor`` (with ``epochs=0``). For each
    trial it:

    1. draws vendor data via ``assign_data``;
    2. maps every vendor's samples through the feature extractor;
    3. scores each vendor against the concatenation of all vendors' features
       with ``get_MMD_values``.

    Args:
        dataset: One of ``'MNIST'``, ``'CIFAR10'``, ``'CreditCard'``, ``'TON'``.
        N: Number of vendors, forwarded to ``assign_data``.
        size: Per-vendor sample size, forwarded to ``assign_data``.
        Q_dataset: Forwarded to ``assign_data``.
        not_huber: Forwarded to ``assign_data``.
        heterogeneity: Forwarded to ``assign_data``.
        n_trials: Number of repetitions.

    Returns:
        List with one entry per trial, each the list returned by
        ``get_MMD_values``.

    Raises:
        NotImplementedError: If ``dataset`` is not supported.
    """
    if dataset == 'MNIST':
        X_train, y_train, X_test, y_test = _get_MNIST32()
        # MNIST
        model = CNN_Net_32()
        # latent dimension d
        d = 10
        epochs = 10

    elif dataset == 'CIFAR10':
        X_train, y_train, X_test, y_test = _get_CIFAR10()
        # CIFAR10
        model = models.resnet18(pretrained=True)
        model.fc = nn.Linear(512, 10)
        d = 10
        epochs = 50
    elif dataset == 'CreditCard':
        X_train, y_train, X_test, y_test = _get_credit_card()
        epochs = 30
        model = LogisticRegression(7, 2)
        d = 7
    elif dataset == 'TON':
        X_train, y_train, X_test, y_test = _get_TON()
        epochs = 30
        model = LogisticRegression(22, 8)
        d = 22
    else:
        raise NotImplementedError(f"P = {dataset} is not implemented.")

    # Resolve the device locally: the module-level `device` only exists when
    # this file is run as a script, so relying on it raised NameError on import.
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    trainloader, testloader = _get_loader(X_train, y_train), _get_loader(X_test, y_test)
    # NOTE(review): epochs=0 is passed, so the per-dataset `epochs` chosen above
    # is unused -- presumably intentional for the "noGen" scalability runs; confirm.
    feature_extractor = get_trained_feature_extractor(model, trainloader, testloader, epochs=0)

    values_hat_over_trials = []
    for _ in tqdm(range(n_trials), desc =f'A total of {n_trials} trials.'):
        # raw data
        D_Xs, D_Ys, V_X, V_Y, _labels = assign_data(N, size, dataset, Q_dataset, not_huber, heterogeneity)

        # Extract features for MMD; build a new list so this works whether
        # assign_data returns a list or a tuple.
        D_Xs = [
            get_extracted(feature_extractor, _get_loader(D_X, D_Y), device)
            for D_X, D_Y in zip(D_Xs, D_Ys)
        ]

        MMD_values_hat = get_MMD_values(D_Xs, None, torch.cat(D_Xs), None, device=device)
        values_hat_over_trials.append(MMD_values_hat)
    return values_hat_over_trials



from memory_profiler import memory_usage
from collections import defaultdict
import pandas as pd
import time

if __name__ == '__main__':
    # Script entry point: for every (N, size) pair, run `run_exp` once and record
    # peak host memory (memory_profiler), wall time and peak CUDA allocation,
    # then write the three tables to CSV/LaTeX under rebuttal/noGen/<dataset>.

    parser = argparse.ArgumentParser(description='Process which dataset to run')
    # parser.add_argument('-N', '--N', help='Number of data vendors.', type=int, required=True, default=5)
    # parser.add_argument('-m', '--size', help='Size of sample datasets.', type=int, required=True, default=1500)
    parser.add_argument('-P', '--dataset', help='Pick the dataset to run.', type=str, required=True)
    parser.add_argument('-Q', '--Q_dataset', help='Pick the Q dataset.', type=str, required=False, choices=['normal', 'EMNIST', 'FaMNIST', 'CIFAR100' , 'CreditCard', 'UGR16'])
    # parser.add_argument('-n_t', '--n_trials', help='Number of trials.', type=int, default=5)
    # parser.add_argument('-nh', '--not_huber', help='Not with huber, meaning with other types of specified heterogeneity.', action='store_true')
    # parser.add_argument('-het', '--heterogeneity', help='Type of heterogeneity.', type=str, default='normal', choices=['normal', 'label', 'classimbalance'])

    # parser.add_argument('-nocuda', dest='cuda', help='Not to use cuda even if available.', action='store_false')
    # parser.add_argument('-cuda', dest='cuda', help='Use cuda if available.', action='store_true')

    cmd_args = parser.parse_args()
    print(cmd_args)

    set_deterministic()
    # Kept at module scope: other code in this file may read the global `device`.
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    use_cuda = device.type == 'cuda'

    dataset = cmd_args.dataset
    Q_dataset = cmd_args.Q_dataset

    res_dir = oj('rebuttal', 'noGen', dataset)
    os.makedirs(res_dir, exist_ok=True)

    print(f"----- Running scalability experiments for without fitting Gen -----")
    memory_data_dict = defaultdict(list)
    time_data_dict = defaultdict(list)
    cuda_memory_data_dict = defaultdict(list)

    sizes = [3000, 5000, 7000, 9000, 10000, 15000, 20000, 30000, 40000, 50000]
    time_data_dict["N-by-size"] = sizes
    memory_data_dict["N-by-size"] = sizes
    cuda_memory_data_dict["N-by-size"] = sizes

    # The logger and formatter are shared by all runs; only the file handler
    # changes. Previously a new handler was added per run and never removed,
    # so later runs were also written into every earlier run's log file.
    logger = logging.getLogger('mem-time_profile_log')
    logger.setLevel(logging.DEBUG)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    # Saved so stdout can be restored after each run's redirection.
    real_stdout = sys.stdout

    for N in [10, 20, 30, 40, 50]:
        for size in sizes:
            # create file handler which logs even debug messages
            fh = logging.FileHandler(oj(res_dir, f"mem-time_profile-N{N}-m{size}.log"))
            fh.setLevel(logging.DEBUG)
            fh.setFormatter(formatter)
            logger.addHandler(fh)
            # Route print/@profile output for this run into its log file.
            sys.stdout = LogFile('mem-time_profile_log', reportIncrementFlag=False)
            try:
                if use_cuda:
                    # The "peak" counter is cumulative since start-up or the
                    # last reset; reset it so each run reports its own peak.
                    torch.cuda.reset_peak_memory_stats(device=device)

                start = time.perf_counter()
                max_memory_usage = memory_usage((run_exp, (dataset, N, size, Q_dataset), {}), max_usage=True)
                time_usage = time.perf_counter() - start

                if use_cuda:
                    # print(torch.cuda.memory_summary(device=device))
                    cuda_memory_stats = torch.cuda.memory_stats(device=device)
                    max_cuda_memory_usage = cuda_memory_stats["allocated_bytes.all.peak"] / 1e6
                else:
                    # memory_stats requires a CUDA device; record "not measured".
                    max_cuda_memory_usage = float('nan')

                print(f"N{N}-m{size}: max_mem {max_memory_usage}, time {time_usage}, cuda_max_mem {max_cuda_memory_usage}. ")
            finally:
                # Always undo the redirection and release the log file, even if
                # the run fails.
                sys.stdout = real_stdout
                logger.removeHandler(fh)
                fh.close()

            time_data_dict[N].append(time_usage)
            memory_data_dict[N].append(max_memory_usage)
            cuda_memory_data_dict[N].append(max_cuda_memory_usage)

    print(memory_data_dict)
    print(time_data_dict)
    print(cuda_memory_data_dict)

    for data_dict, stem in (
        (memory_data_dict, 'mem-results'),
        (time_data_dict, 'time-results'),
        (cuda_memory_data_dict, 'cuda-mem-results'),
    ):
        df = pd.DataFrame(data_dict)
        df.to_csv(oj(res_dir, f'{stem}.csv'), index=False)
        df.to_latex(oj(res_dir, f'{stem}.tex'), index=False)
        print(df)